from jqc import Parse,jqc_plot
import time
import numpy
import scipy.constants
import matplotlib.pyplot as pyplot
import warnings
from sympy.physics.wigner import wigner_3j
import os
import scipy.interpolate as interp
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
from matplotlib.transforms import TransformedBbox
from mpl_toolkits.axes_grid1.inset_locator import BboxPatch, BboxConnector,BboxConnectorPatch


# directory containing this script; the experimental data files are looked up relative to it
cwd = os.path.dirname(os.path.abspath(__file__))

############# Calculating errors adds two extra Hamiltonian evaluations and
############# diagonalisations per point; the original estimate is at least ~40 s
############# per enabled calculation (presumably each True entry in the
############# Intensity/Polarisation flags further down)
calculate_errors = True

#Universal constants (SI units)
pi = numpy.pi
c =scipy.constants.c
eps0 = scipy.constants.epsilon_0
bohr = scipy.constants.physical_constants['Bohr radius'][0]
h = scipy.constants.h

conv_Debye_SI = 3.336e-30 #approximate Debye -> C m conversion factor, so dipole moments can be given in Debye
#setup the plot style
jqc_plot.plot_style('normal')
# default colour cycle of the chosen style (colours is reassigned to a fixed palette later in the script)
colours = pyplot.rcParams['axes.prop_cycle'].by_key()['color']

# define the best-fit values of the polarizability, given in atomic units
# (multiples of 4*pi*eps0*bohr**3) and converted here to SI
a0 = 4*pi*eps0*2020*bohr**3
a2 = 4*pi*eps0*1997*bohr**3

#and their errors (same units)
err_a0 = 4*pi*eps0*20*bohr**3
err_a2 = 4*pi*eps0*6*bohr**3
#Permanent dipole moment, in C m
D0 = conv_Debye_SI*1.225 #RbCs

#D0 = conv_Debye_SI*2.7 #NaK

#rotational (and centrifugal distortion) constants for RbCs, stored as
#energies: h times 490.173994 MHz and h times 213 Hz respectively
B0 = h *490.173994e6
Dv = h *213

#Experimental Parameters
WattsPerVolt = 0.985 #calibration of DT power before cell
atten = 0.926645296 #transmission at magic angle
Ts = 0.781 #transmission of s-polarised light through cell
Tp = 0.99468 #transmission of p-polarisation

n = time.time() #current time, used to time the Hamiltonian generation below
#Generate the Hamiltonian from the mathematica solution, returns an array of
#lambda-esque functions. evaluate() calls each element with the arguments
#(eps0, c, E, I, beta, a0, a2, B0, Dv, D0) to obtain a numeric matrix.

Hamiltonian = Parse.Generate_Hamiltonian_Stark()
m = time.time()
print("Generating took:",m-n,"s")

def evaluate(Ham,vars,Phys_const,mol_const):
    ''' Evaluate a matrix of callables into a numeric matrix.

    Each element Ham[idx] is called as Ham[idx](*Phys_const, *vars, *mol_const)
    and the result is stored at the same index of a float array. This is
    needed because the arguments cannot be passed element-wise to the whole
    array at once.

    Ham        -- array of callables (e.g. the generated Hamiltonian)
    vars       -- variable parameters, passed after Phys_const
    Phys_const -- physical constants, passed first
    mol_const  -- molecular constants, passed last
    Returns a numpy float64 array with the same shape as Ham. '''
    mol = numpy.zeros(Ham.shape)
    # walk Ham's own shape (not a module-level matrix) so any matrix of
    # callables can be evaluated; ndindex visits elements in row-major order
    for idx in numpy.ndindex(Ham.shape):
        mol[idx] = Ham[idx](*Phys_const,*vars,*mol_const)
    return mol

def DipoleMoment(state1,state2):
    ''' Return the transition dipole moment between two rotational states.

    state1, state2 -- (N, M_N) tuples describing |N,M_N> basis states.
    The result is D0 * sqrt((2N1+1)(2N2+1)) * (-1)**M1 times the two Wigner
    3j symbols (N1 1 N2; -M1 q M2) and (N1 1 N2; 0 0 0), with q = M1 - M2.
    Both 3j symbols vanish unless the selection rules are satisfied. '''
    (n_a, m_a), (n_b, m_b) = state1, state2
    q = m_a - m_b  # spherical component of the dipole operator linking the states
    angular = numpy.float64(wigner_3j(n_a, 1, n_b, -m_a, q, m_b))
    reduced = numpy.float64(wigner_3j(n_a, 1, n_b, 0, 0, 0))
    prefactor = numpy.sqrt((2*n_a + 1)*(2*n_b + 1))
    return D0*prefactor*(-1)**m_a*angular*reduced

def DDI(d0,d1,r):
    ''' Return the dipole-dipole interaction d0*d1/(4*pi*eps0*r**3), in SI.

    d0, d1 -- dipole moments of the two particles
    r      -- separation between them
    Assumes aligned dipoles; the result differs by a factor of 2 from the
    maximum dipole-dipole interaction. '''
    coulomb_factor = 4*pi*eps0
    return d0*d1/(coulomb_factor*r**3)

def chi2(energylevel,data):
    ''' Return the chi-squared statistic of measured data against a model curve.

    energylevel -- positional arguments for scipy.interpolate.interp1d,
                   normally the pair (x, y) of the model (linear by default)
    data        -- 2D array whose first three columns are x, measured value
                   and its uncertainty
    Raises ValueError if any data x lies outside the model's x range. '''
    model = interp.interp1d(*energylevel)
    x_meas = data[:,0]
    y_meas = data[:,1]
    sigma = data[:,2]
    terms = (y_meas-model(x_meas))**2/sigma**2
    return numpy.sum(terms)

#parameters for the model
#intensity 2P/(pi w^2) used for the polarisation scan, in W/m^2
#(presumably a 1.6 V power setpoint and a 172.6 um beam radius)
I0 = Tp*(2*1.6*WattsPerVolt)/(pi*172.6e-6**2) #W/m^2
#polarisation angle used for the intensity scan (54.7356 deg), in radians and degrees
Beta = numpy.deg2rad(54.7356)
beta = numpy.rad2deg(Beta)

#ranges of the scanned parameters
BetaMin = numpy.deg2rad(-15)
BetaMax = numpy.deg2rad(90+15)
Imax = 20e7
Emax = 10000e2 #V/m
steps = 1500 #number of values to calculate, 1500 takes 20 s per run

#highest rotational level N to PLOT; states up to N=8 are included in the calculation
Nmax = 1
Nmin = 0
#never plot beyond N=8
Nmax = int(min(Nmax, 8))

#basis states |N,M_N> for N = 0..8, ordered by N then M_N
States = [(rot, proj) for rot in range(9) for proj in range(-rot, rot + 1)]

#matrix of transition dipole elements between every pair of basis states
DipoleMatrix = numpy.array([[DipoleMoment(bra, ket) for ket in States]
                            for bra in States])

#ket label for each basis state, used to identify each line
labels = ["|{},{}$\\rangle$".format(rot, proj) for rot, proj in States]

#number of basis states with N <= Nmax, i.e. how many levels are plotted
Statemax = numpy.sum([2*rot + 1 for rot in range(Nmax + 1)])

#prepare for the calculation: energy of every level at each step of a scan
Energy = numpy.zeros((Hamiltonian.shape[0],steps))
if calculate_errors:
    #same layout as Energy, holding the uncertainty of each level
    Energy_err = numpy.zeros_like(Energy)
#eigenvector matrix (basis x levels) at each step of a scan
EigStates = numpy.zeros((Hamiltonian.shape[0],Hamiltonian.shape[1],steps))

#start of the timed section reported at the end of the script
m = time.time()
Dipole = numpy.zeros_like(Energy)

#text positions (data coordinates), labels and colours of the levels in panel (a)
x = [75,65,27,0]
y = [-20,-145,75,-145]
state = ["$|0,0\\rangle$","$|1,-1\\rangle$","$|N=1,M_N=1\\rangle$","$|1,0\\rangle$"]
colours = [jqc_plot.colours[key] for key in ('green','purple','blue','red')]

#figure layout: tall polarisation panel on the left spanning all rows,
#three stacked intensity panels on the right
fig = pyplot.figure("Efield")
grid = gridspec.GridSpec(3,2,width_ratios=[2,1])
ax_beta = fig.add_subplot(grid[:,0])

#which scans to run for each of the three electric fields, and panel labels
Intensity = [True,True,True]
Polarisation = [True,False,False]
axislabel = ["(i)","(ii)","(iii)"]

def _diagonalise(Efield,intensity,angle):
    ''' Diagonalise the Hamiltonian at one (electric field, intensity,
    polarisation angle) point.

    Returns (energies, states, errors):
      energies -- eigenvalues sorted in ascending order (real part)
      states   -- matching eigenvectors as columns (real part)
      errors   -- per-level energy uncertainty propagated from err_a0 and
                  err_a2, or None when calculate_errors is False

    Taking the real part explicitly reproduces the complex->float cast that
    storing into the Energy/EigStates arrays performs. It also avoids relying
    on numpy.ComplexWarning, which NumPy 2 no longer exposes at top level. '''
    point = (Efield,intensity,angle)
    H = evaluate(Hamiltonian,point,(eps0,c),(a0,a2,B0,Dv,D0))
    values,vectors = numpy.linalg.eig(H)
    order = numpy.argsort(values)
    energies = numpy.real(values[order])
    states = numpy.real(vectors[:,order])
    if not calculate_errors:
        return energies,states,None
    #functional approach to error analysis: shift each polarisability by its
    #error in turn and add the resulting level shifts in quadrature,
    #err ~= sqrt(|f(A+dA,B)-f(A,B)|**2 + |f(A,B+dB)-f(A,B)|**2)
    Herr_a0 = evaluate(Hamiltonian,point,(eps0,c),(a0+err_a0,a2,B0,Dv,D0))
    Herr_a2 = evaluate(Hamiltonian,point,(eps0,c),(a0,a2+err_a2,B0,Dv,D0))
    shift_a0 = numpy.abs(numpy.sort(numpy.linalg.eigvals(Herr_a0))-values[order])
    shift_a2 = numpy.abs(numpy.sort(numpy.linalg.eigvals(Herr_a2))-values[order])
    return energies,states,numpy.sqrt(shift_a0**2+shift_a2**2)

ax = None  #most recent intensity axes; each new intensity panel shares its x-axis
for j,Efield in enumerate([300e2,150e2,100e2]):
    #evaluate the zero-point for transition frequency/ energy level shifts
    H0 = evaluate(Hamiltonian,(Efield,0,0),(eps0,c),(a0,a2,B0,Dv,D0))
    Energy0=numpy.sort(numpy.linalg.eigvals(H0))
    #experimental data for this field; os.path.join keeps the lookup portable
    #instead of hard-coding Windows path separators
    data_dir = os.path.join(cwd,"EXP Data","E={:.0f} Vcm".format(Efield*1e-2))

    if Intensity[j]:
        print("Intensity calculation at E ={:.1f} V/cm".format(Efield*1e-2))
        ax = fig.add_subplot(grid[j,1],sharex=ax)

        #varying intensity with fixed polarisation and electric field
        Laser = numpy.linspace(0,Imax,steps)
        for i,IDT in enumerate(Laser):
            energies,states,err = _diagonalise(Efield,IDT,Beta)
            Energy[:,i] = energies
            EigStates[:,:,i] = states
            if err is not None:
                Energy_err[:,i] = err

        #shift of the transition from the lowest level to level i, relative to
        #its value at zero intensity, in kHz
        for i in range(1,Statemax):
            delta = (Energy[i,:]-Energy[0,:])-(Energy0[i]-Energy0[0])
            p = ax.plot(Laser*1e-7,1e-3*delta/h,label=labels[i],color=colours[i])
            if calculate_errors:
                #translucent band: level-i and lowest-level errors in quadrature
                error = numpy.sqrt(Energy_err[i,:]**2+Energy_err[0,:]**2)
                ax.fill_between(Laser*1e-7,1e-3*(delta-error)/h,1e-3*(delta+error)/h,
                                alpha=0.5,zorder=1.0+0.05*i,color=p[0].get_color(),
                                lw=0)

        try:
            int_data = numpy.genfromtxt(os.path.join(data_dir,"Intdata.csv"),
                                        delimiter=',')
        except OSError:
            #no file found
            print("No data found")
        else:
            #column 0 (presumably a power-control voltage) -> power -> intensity
            #2P/(pi w^2), expressed in kW/cm^2
            P = int_data[:,0]*WattsPerVolt*atten
            P = 1e-7*(P*2)/(pi*172.6e-6**2)
            #frequencies relative to the first row and their uncertainties,
            #scaled by 1e3 onto the kHz axis
            int_data[:,1] = 1e3*(int_data[:,1]-int_data[0,1])
            int_data[:,2] = int_data[:,2]*1e3
            int_data[:,0] = P
            ax.errorbar(int_data[:,0],int_data[:,1],yerr=int_data[:,2],
                        color='k',fmt='o')
            #chi-squared per data point of level 3 against the data
            model = (Laser*1e-7,
                     1e-3*((Energy[3,:]-Energy[0,:])-(Energy0[3]-Energy0[0]))/h)
            print("Intensity chi-squared:",chi2(model,int_data)/len(int_data[:,1]))

        ax.set_xlim(0,10)
        ax.set_ylim(-10,70)

        ax.text(0.98,0.8,axislabel[j],fontsize=20,transform=ax.transAxes,
                horizontalalignment='right')
        if j ==0:
            ax.text(0.02,0.8,"(b)",transform=ax.transAxes,fontsize=20)
        if j !=2:
            #only the bottom intensity panel keeps its x tick labels
            pyplot.setp(ax.get_xticklabels(),visible=False)

    if Polarisation[j]:
        print("Polarisations calculation at E ={:.1f} V/cm".format(Efield*1e-2))
        #Varying polarisation with fixed intensity and electric field
        for i,B in enumerate(numpy.linspace(BetaMin,BetaMax,steps)):
            energies,states,err = _diagonalise(Efield,I0,B)
            Energy[:,i] = energies
            EigStates[:,:,i] = states
            if err is not None:
                Energy_err[:,i] = err

        Pol = numpy.linspace(numpy.rad2deg(BetaMin),numpy.rad2deg(BetaMax),steps)
        for i in range(0,Statemax):
            delta = (Energy[i,:]-Energy[0,:])-(Energy0[i]-Energy0[0])
            p = ax_beta.plot(Pol,1e-3*delta/h,zorder=1.5+0.05*i,color=colours[i])
            ax_beta.text(x[i],y[i],state[i],color=p[0].get_color())
            if calculate_errors:
                #plot the error as a translucent region either side of the best fit
                error = numpy.sqrt(Energy_err[i,:]**2+Energy_err[0,:]**2)
                ax_beta.fill_between(Pol,1e-3*(delta-error)/h,1e-3*(delta+error)/h,
                                     alpha=0.5,zorder=1.0+0.05*i,
                                     color=p[0].get_color(),lw=0)
        ax_beta.set_xlabel("Polarisation angle, $\\beta$ ($^\\circ$)")
        ax_beta.set_ylabel("Transition Frequency Shift (kHz)")

        beta_data = None
        try:
            beta_data = numpy.genfromtxt(os.path.join(data_dir,"Betadata.csv"),
                                         delimiter=',')
        except OSError:
            print("No data found")
        else:
            #amplitude transmissions of p and s light through the cell
            tp = numpy.sqrt(Tp)
            ts = numpy.sqrt(Ts)
            #column 0 is converted to a polarisation angle via 2*(value-132.25)
            #(presumably a half-wave-plate reading with zero at 132.25 deg - TODO confirm)
            angles = 2*(beta_data[:,0]-132.25)
            #angle after the unequal s/p transmission of the cell
            angles_cell = numpy.rad2deg(
                            numpy.arctan((tp/ts)*numpy.tan(numpy.deg2rad(angles))))
            #row 0 is the reference; frequencies (MHz, matching the model used
            #in the chi-squared below) become shifts relative to it
            beta_data[:,1] = beta_data[:,1]-beta_data[0,1]
            beta_data[:,0] = angles_cell
            #values AND uncertainties are scaled by 1e3 so both are in kHz
            ax_beta.errorbar(angles_cell[1:],1e3*beta_data[1:,1],
                             yerr=1e3*beta_data[1:,2],color='k',fmt='o')
            model = (Pol,1e-6*((Energy[3,:]-Energy[0,:])-(Energy0[3]-Energy0[0]))/h)
            print("Polarisation chi-squared:",
                    chi2(model,beta_data[1:,:])/len(beta_data[1:,1]))

        ax_beta.set_xlim(-5,100)
        ax_beta.set_xticks([0,45,90])
        ax_beta.set_ylim([-150,100])

        if beta_data is not None:
            #link the point (beta, first measured shift) to the left-hand corners
            #of the current intensity panel (assumes row 1 of Betadata.csv was
            #taken at this angle - TODO confirm)
            anchor = (beta,1e3*beta_data[1,1])
            for corner in [(0,0),(0,1)]:
                link = patches.ConnectionPatch(xyA=anchor,xyB=corner,
                        coordsA='data',coordsB='axes fraction',
                        axesA=ax_beta,axesB=ax,color=jqc_plot.colours['grayblue'])
                ax_beta.add_patch(link)


#panel label for the polarisation plot and a shared y-axis title placed in
#the axes coordinates of the bottom intensity panel (y=1.5 centres it on the
#three stacked panels)
ax_beta.text(0.01,0.93,"(a)",transform=ax_beta.transAxes,fontsize=20)
ax.text(-0.5,1.5,"Transition Frequency Shift (kHz)",rotation=90,fontsize=15,
        verticalalignment='center',transform=ax.transAxes)
ax.set_xlabel("Intensity (kW$\\,$cm$^{-2}$)")

#upward arrow and caption to the right of the intensity panels
ax.arrow(1.08,0,0,3,length_includes_head=True,transform=ax.transAxes,
         clip_on=False,width=.02,color='k')
ax.text(1.11,1.5,"Increasing $E$",rotation=90,fontsize=15,
        verticalalignment='center',transform=ax.transAxes)

#dashed vertical line at beta across the current y-range of panel (a)
y0,y1 = ax_beta.get_ylim()
ax_beta.plot([beta]*2,[y0,y1],ls='--',color='k')

fig.tight_layout()
fig.subplots_adjust(hspace=0,wspace=0.41)

#elapsed time since m was last set (just before the main calculation loop)
l = time.time()
print("Script took:",l-m,"s")
for out_name in ("MAGIC.pdf","MAGIC.png"):
    fig.savefig(out_name)
pyplot.show()
